import openmc
import numpy as np
from nuscale.surfaces import assembly_pitch, surfs, pin_cell_pitch

def _fission_mesh_tally(mesh_name, tally_name, x_grid, y_grid, z_grid):
    """Build a single fission tally filtered over a rectilinear mesh.

    Parameters
    ----------
    mesh_name : str
        Name given to the ``openmc.RectilinearMesh``.
    tally_name : str
        Name given to the ``openmc.Tally``.
    x_grid, y_grid, z_grid : numpy.ndarray
        Mesh edge coordinates along each axis.

    Returns
    -------
    openmc.Tally
        Tally scoring ``'fission'`` with one ``openmc.MeshFilter`` on the mesh.
    """
    mesh = openmc.RectilinearMesh(name=mesh_name)
    mesh.x_grid = x_grid
    mesh.y_grid = y_grid
    mesh.z_grid = z_grid

    # Mesh, filter and tally are created in this order on purpose so that
    # any automatically assigned object IDs stay stable.
    mesh_filter = openmc.MeshFilter(mesh=mesh)

    tally = openmc.Tally(name=tally_name)
    tally.filters = [mesh_filter]
    tally.scores = ['fission']
    return tally


def mesh_tallies(fa_radial: bool = True, pin_radial: bool = False,
                 axial: bool = False):
    """Assemble the requested fission-rate mesh tallies.

    Every mesh covers a square of side ``7 * assembly_pitch`` centred on the
    origin in x and y.

    Parameters
    ----------
    fa_radial : bool, optional
        Add a 7x7 radial mesh tally (8 edges per axis) with a single axial
        bin between the ``'fuel bottom'`` and ``'fuel top'`` surfaces.
    pin_radial : bool, optional
        Add a 119x119 radial mesh tally (120 edges per axis) with the same
        single axial bin.
    axial : bool, optional
        Add a tally with one radial bin and 21 axial bins defined by a fixed
        table of z edges.

    Returns
    -------
    openmc.Tallies
        Collection holding the selected tallies, in the order
        fa_radial, pin_radial, axial. Empty if no flag is set.
    """
    # Looked up unconditionally (as before), so a missing surface is reported
    # regardless of which tallies were requested.
    bottom_fuel = surfs['fuel bottom'].z0
    top_fuel = surfs['fuel top'].z0

    # Half the side length of the 7-pitch-wide region covered by every mesh.
    half_width = 7. * assembly_pitch / 2.

    tallies = []

    if fa_radial:
        # 8 edges -> 7 bins per axis, i.e. one bin per assembly pitch.
        tallies.append(_fission_mesh_tally(
            "2D mesh 7x7",
            'FA Fission Rates fuel assembly',
            np.linspace(-half_width, half_width, 8, endpoint=True),
            np.linspace(-half_width, half_width, 8, endpoint=True),
            np.array([bottom_fuel, top_fuel]),
        ))

    if pin_radial:
        # 120 edges -> 119 bins per axis.
        # NOTE(review): presumably 7 assemblies x 17 pin positions; this only
        # lines up with the pins if assembly_pitch == 17 * pin pitch -- confirm.
        tallies.append(_fission_mesh_tally(
            "2D mesh 119x119",
            'FA Fission Rates pin-by-pin',
            np.linspace(-half_width, half_width, 120, endpoint=True),
            np.linspace(-half_width, half_width, 120, endpoint=True),
            np.array([bottom_fuel, top_fuel]),
        ))

    if axial:
        # Fixed, non-uniform axial edges.
        # NOTE(review): these are hard-coded and independent of the
        # 'fuel bottom' / 'fuel top' surfaces used above -- confirm they
        # match the model's axial layout.
        axial_edges = np.array([
            11.365,
            14.920,
            19.365,
            31.017,
            42.670,
            54.322,
            65.974,
            70.419,
            82.071,
            93.724,
            105.376,
            117.028,
            121.473,
            133.125,
            144.778,
            156.430,
            168.082,
            172.527,
            182.236,
            191.946,
            201.656,
            211.365,
        ])
        tallies.append(_fission_mesh_tally(
            "Axial mesh",
            'Axial fission rates',
            np.array([-half_width, half_width]),
            np.array([-half_width, half_width]),
            axial_edges,
        ))

    return openmc.Tallies(tallies)
